# ggplot backend for the xgboost plotting facilities


#' @rdname xgb.plot.importance
#' @export
xgb.ggplot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL, 
                                  rel_to_first = FALSE, n_clusters = c(1:10), ...) {
  
  importance_matrix <- xgb.plot.importance(importance_matrix, top_n = top_n, measure = measure,
                                           rel_to_first = rel_to_first, plot = FALSE, ...)
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required", call. = FALSE)
  }
  if (!requireNamespace("Ckmeans.1d.dp", quietly = TRUE)) {
    stop("Ckmeans.1d.dp package is required", call. = FALSE)
  }
  
  clusters <- suppressWarnings(
    Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix$Importance, n_clusters)
  )
  importance_matrix[, Cluster := as.character(clusters$cluster)]

  plot <-
    ggplot2::ggplot(importance_matrix, 
                    ggplot2::aes(x = factor(Feature, levels = rev(Feature)), y = Importance, width = 0.05),
                    environment = environment()) + 
    ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity", position = "identity") + 
    ggplot2::coord_flip() + 
    ggplot2::xlab("Features") + 
    ggplot2::ggtitle("Feature importance") + 
    ggplot2::theme(plot.title = ggplot2::element_text(lineheight = .9, face = "bold"), 
                   panel.grid.major.y = ggplot2::element_blank())
  return(plot)
}


#' @rdname xgb.plot.deepness
#' @export
xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight")) {

  if (!requireNamespace("ggplot2", quietly = TRUE))
    stop("ggplot2 package is required for plotting the graph deepness.", call. = FALSE)

  which <- match.arg(which)
  
  dt_depths <- xgb.plot.deepness(model = model, plot = FALSE)
  dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
  setkey(dt_summaries, 'Depth')

  if (which == "2x1") {
    p1 <-
      ggplot2::ggplot(dt_summaries) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = N), stat = "Identity") +
      ggplot2::xlab("") +
      ggplot2::ylab("Number of leafs") +
      ggplot2::ggtitle("Model complexity") +
      ggplot2::theme(
        plot.title = ggplot2::element_text(lineheight = 0.9, face = "bold"),
        panel.grid.major.y = ggplot2::element_blank(),
        axis.ticks = ggplot2::element_blank(),
        axis.text.x = ggplot2::element_blank()
      )
  
    p2 <- 
      ggplot2::ggplot(dt_summaries) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = Cover), stat = "Identity") + 
      ggplot2::xlab("Leaf depth") +
      ggplot2::ylab("Weighted cover")
  
    multiplot(p1, p2, cols = 1)
    return(invisible(list(p1, p2)))
    
  } else if (which == "max.depth") {
    p <-
      ggplot2::ggplot(dt_depths[, max(Depth), Tree]) +
      ggplot2::geom_jitter(ggplot2::aes(x = Tree, y = V1),
                           height = 0.15, alpha=0.4, size=3, stroke=0) + 
      ggplot2::xlab("tree #") +
      ggplot2::ylab("Max tree leaf depth")
    return(p)
    
  } else if (which == "med.depth") {
    p <-
      ggplot2::ggplot(dt_depths[, median(as.numeric(Depth)), Tree]) +
      ggplot2::geom_jitter(ggplot2::aes(x = Tree, y = V1),
                           height = 0.15, alpha=0.4, size=3, stroke=0) + 
      ggplot2::xlab("tree #") +
      ggplot2::ylab("Median tree leaf depth")
    return(p)

  } else if (which == "med.weight") {
    p <-
      ggplot2::ggplot(dt_depths[, median(abs(Weight)), Tree]) +
      ggplot2::geom_point(ggplot2::aes(x = Tree, y = V1),
                          alpha=0.4, size=3, stroke=0) + 
      ggplot2::xlab("tree #") +
      ggplot2::ylab("Median absolute leaf weight")
    return(p)
  }
}

# Plot multiple ggplot graph aligned by rows and columns.
# ... the plots
# cols number of columns
# internal utility function
multiplot <- function(..., cols = 1) {
  plots <- list(...)
  num_plots = length(plots)
  
  layout <- matrix(seq(1, cols * ceiling(num_plots / cols)),
                   ncol = cols, nrow = ceiling(num_plots / cols))
  
  if (num_plots == 1) {
    print(plots[[1]])
  } else {
    grid::grid.newpage()
    grid::pushViewport(grid::viewport(layout = grid::grid.layout(nrow(layout), ncol(layout))))
    for (i in 1:num_plots) {
      # Get the i,j matrix positions of the regions that contain this subplot
      matchidx <- as.data.table(which(layout == i, arr.ind = TRUE))
      
      print(
        plots[[i]], vp = grid::viewport(
          layout.pos.row = matchidx$row,
          layout.pos.col = matchidx$col
        )
      )
    }
  }
}

globalVariables(c(
  "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme",
  "element_blank", "element_text", "V1", "Weight"
))
